import sys

import torch
from torch import optim
import torch.nn as nn
from torch.autograd import Variable

from mmd import rbf_mmd2


def get_reference_(dataloader, size=10000, dataset='CaliH', epochs=5, netD_output_dim=16, use_GMM=False, device=torch.device('cuda')):
    """Draw a reference sample from a parametric family fitted to the pooled data.

    Two families are supported:

    * ``use_GMM=True``: every candidate mixture returned by
      ``_get_mixture_dictionary(dataset)`` draws ``len(data)`` samples, and the
      sample with the smallest ``rbf_mmd2`` to the pooled data is returned.
    * otherwise: an sklearn ``KernelDensity`` with default parameters is fit to the
      pooled data and ``max(len(data), size)`` samples are drawn with a fixed
      ``random_state`` (1234).

    Args:
        dataloader: iterable of ``(x, y)`` batches; only ``x`` is used.
        size: minimum number of samples drawn in the KDE branch (unused for GMM).
        dataset: key passed to ``_get_mixture_dictionary`` (GMM branch only).
        epochs: accepted for interface compatibility; unused.
        netD_output_dim: accepted for interface compatibility; unused.
        use_GMM: select the GMM branch instead of KDE.
        device: device of the returned reference tensor.

    Returns:
        ``(reference, None, None)``: ``reference`` is a float32 tensor on ``device``;
        the two ``None`` entries are placeholders kept for interface compatibility.

    Raises:
        ValueError: in the GMM branch, if no mixture produced a sample (e.g. the
            mixture dictionary is empty).
    """
    if use_GMM:
        print("Using GMM the parametric family.")
        from _Ours_utils import _get_mixture_dictionary
        mixtures = _get_mixture_dictionary(dataset)

        data = torch.cat([x for x, _ in dataloader])
        n_samples = len(data)

        best_MMD2, best_generated = float('inf'), None
        for mixture in tqdm(mixtures, desc=f"Finding the best mixture out of {len(mixtures)} mixtures."):
            # mixture.sample returns (samples, component_labels); keep only the samples.
            generated = torch.from_numpy(mixture.sample(n_samples)[0]).float()
            mmd2 = rbf_mmd2(data, generated)
            if mmd2 < best_MMD2:
                best_MMD2, best_generated = mmd2, generated

        if best_generated is None:
            # No candidate beat +inf (empty dictionary or NaN MMDs): fail with a clear
            # message instead of an AttributeError on None.to(device).
            raise ValueError(f"No usable mixture found for dataset {dataset!r}.")
        return best_generated.to(device), None, None

    print("Using KDE method for the parametric family.")
    # KDE method: sklearn works on CPU numpy arrays.
    data = torch.cat([x for x, _ in dataloader]).detach().cpu().numpy()

    from sklearn.neighbors import KernelDensity as gaussian_kde
    kernel = gaussian_kde().fit(data)
    generated = kernel.sample(max(len(data), size), random_state=1234)

    # Match the GMM branch (float32 on `device`) so callers can concatenate the
    # reference with on-device data without a device/dtype mismatch.
    return torch.from_numpy(generated).float().to(device), None, None


from run_Ours_genIter import get_MMD_values_uneven

def get_MMD_values(D_Xs, D_Ys, V_X, V_Y, netD, sigma_list=None, device=torch.device('cuda')):
    """Return ``-sqrt(MMD^2)`` between the reference ``V_X`` and each ``D_X`` in ``D_Xs``.

    Args:
        D_Xs: iterable of tensors, one per data source.
        D_Ys: unused; kept for interface compatibility.
        V_X: reference sample, as a tensor or a numpy array.
        V_Y: unused; kept for interface compatibility.
        netD: optional module; if given it is moved to ``device`` and put in eval mode.
            NOTE(review): it is never applied to the data below — confirm this is intended.
        sigma_list: kernel bandwidths forwarded to ``rbf_mmd2``; ``None`` means
            ``[1, 2, 5, 10]``.
        device: device on which the MMD is computed.

    Returns:
        A list of Python floats, one per ``D_X``, each equal to
        ``-sqrt(max(MMD^2, 1e-6))``.
    """
    # Fresh list per call (avoids a shared mutable default argument).
    if sigma_list is None:
        sigma_list = [1, 2, 5, 10]

    results = []

    if netD is not None:
        netD = netD.to(device)
        netD.eval()

    if isinstance(V_X, np.ndarray):
        V_X = torch.from_numpy(V_X)

    V_X = V_X.to(device)
    with torch.no_grad():
        for D_X in D_Xs:
            D_X = D_X.to(device)

            # Concatenating then splitting gives both halves a common (promoted) dtype.
            outputs = torch.cat([V_X, D_X], dim=0)
            output_real = outputs[:len(V_X)]
            output_fake = outputs[len(V_X):]

            MMD2 = rbf_mmd2(output_real, output_fake, sigma_list)  # allows unequal sizes
            # Clamp before sqrt: MMD2 may be tiny (or negative for some estimators).
            # The old max(1e-6, MMD2) returned a Python float in that case, which
            # torch.sqrt rejects.
            mmd = torch.sqrt(torch.clamp(torch.as_tensor(MMD2), min=1e-6))
            results.append(-mmd.item())

    return results


from copy import deepcopy
from utils import cwd, set_deterministic, save_results

from data_utils import _get_loader
import numpy as np
import torch

from reg_data_utils import assign_data
from os.path import join as oj

from tqdm import tqdm
import argparse

from scipy.stats import pearsonr, spearmanr

# Label for the experiment; the __main__ driver later extends it with a
# '_GMM' or '_KDE' suffix for one of the saved result sets.
baseline = 'Ours'


class options:
    """Fixed hyper-parameters read by the experiment driver (used as a plain namespace)."""

    # Note: the driver decides on CUDA from the -cuda/-nocuda command-line flags.
    cuda: bool = True
    # Batch size for the data loader and the MMD computations.
    batch_size: int = 256
    n_filters: int = 16
    # Forwarded to get_reference_ as its `epochs` argument.
    epochs: int = 50

if __name__ == '__main__':

    print(f"----- Running experiment for {baseline} -----")

    parser = argparse.ArgumentParser(description='Process which dataset to run for regression.')
    # -N and -m are required, so argparse would never use a default for them.
    parser.add_argument('-N', '--N', help='Number of data vendors.', type=int, required=True)
    parser.add_argument('-m', '--size', help='Size of sample datasets.', type=int, required=True)
    parser.add_argument('-P', '--dataset', help='Pick the dataset to run.', type=str, required=True)
    parser.add_argument('-Q', '--Q_dataset', help='Pick the Q dataset.', type=str, required=True, choices=['KingH', 'Census17'])
    parser.add_argument('-n_t', '--n_trials', help='Number of trials.', type=int, default=5)
    parser.add_argument('-gmm', dest='gmm', help='Whether to use GMM for generator distribution.', action='store_true')
    parser.add_argument('-kde', dest='gmm', help='Whether to use KDE for generator distribution.', action='store_false')

    parser.add_argument('-nocuda', dest='cuda', help='Not to use cuda even if available.', action='store_false')
    parser.add_argument('-cuda', dest='cuda', help='Use cuda if available.', action='store_true')
    # Each pair of flags shares one dest; state the defaults explicitly
    # (KDE and CUDA unless overridden).
    parser.set_defaults(gmm=False, cuda=True)

    cmd_args = parser.parse_args()
    print(cmd_args)

    dataset = cmd_args.dataset
    Q_dataset = cmd_args.Q_dataset
    N = cmd_args.N
    size = cmd_args.size
    n_trials = cmd_args.n_trials
    use_GMM = cmd_args.gmm
    cuda = cmd_args.cuda

    if torch.cuda.is_available() and cuda:
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    set_deterministic()

    # All result sets below are saved under the same experiment name.
    exp_name = oj('regression', f'{dataset}_vs_{Q_dataset}-N{N} m{size} n_trials{n_trials}')

    values_over_trials, values_hat_over_trials = [], []
    values_hat_no_gen_over_trials = []
    values_mmd2_over_trials, values_hat_mmd2_over_trials = [], []
    for _ in tqdm(range(n_trials), desc=f'A total of {n_trials} trials.'):
        # Raw data: presumably N per-vendor splits (D_Xs, D_Ys) plus a validation set (V_X, V_Y).
        D_Xs, D_Ys, V_X, V_Y = assign_data(N, size, dataset, Q_dataset)

        loader = _get_loader(torch.cat(D_Xs), torch.cat(D_Ys), batch_size=options.batch_size)
        reference, netD, _ = get_reference_(loader, dataset=dataset, epochs=options.epochs, netD_output_dim=V_X.shape[1], device=device, use_GMM=use_GMM)

        # MMD (and MMD^2) of each vendor against the validation set.
        MMD_values = get_MMD_values_uneven(D_Xs, None, V_X, None, netD, device=device, batch_size=options.batch_size)
        print("MMD values:", MMD_values)
        values_over_trials.append(MMD_values)

        MMD2_values = get_MMD_values_uneven(D_Xs, None, V_X, None, netD, device=device, squared=True, batch_size=options.batch_size)
        values_mmd2_over_trials.append(MMD2_values)

        pooled_X = torch.cat(D_Xs).to(device)
        # `reference` is not guaranteed to be on `device` (e.g. a CPU tensor from the
        # KDE path), so move it explicitly; torch.cat requires a single device.
        combined_reference = torch.cat([reference.to(device), pooled_X]).float()
        MMD_values_hat = get_MMD_values_uneven(D_Xs, None, combined_reference, None, netD, device=device, batch_size=options.batch_size)
        values_hat_over_trials.append(MMD_values_hat)

        # Reference = pooled vendor data only (no generated samples, no netD).
        MMD_values_hat_no_gen = get_MMD_values_uneven(D_Xs, None, pooled_X, None, None, device=device, batch_size=options.batch_size)
        values_hat_no_gen_over_trials.append(MMD_values_hat_no_gen)

        MMD2_values_hat = get_MMD_values_uneven(D_Xs, None, combined_reference, None, netD, device=device, squared=True, batch_size=options.batch_size)
        values_hat_mmd2_over_trials.append(MMD2_values_hat)

    # Ours, no generated reference.
    results = {'values_over_trials': values_over_trials, 'values_hat_over_trials': values_hat_no_gen_over_trials,
               'N': N, 'size': size, 'n_trials': n_trials, 'use_GMM': use_GMM}
    save_results(baseline=baseline, exp_name=exp_name, **results)

    # Ours with the generated+pooled reference only.
    # NOTE(review): both keys intentionally(?) hold the same list — confirm.
    results = {'values_over_trials': values_hat_over_trials, 'values_hat_over_trials': values_hat_over_trials,
               'N': N, 'size': size, 'n_trials': n_trials, 'use_GMM': use_GMM}
    baseline += '_GMM' if use_GMM else '_KDE'
    save_results(baseline=baseline, exp_name=exp_name, **results)

    # MMD squared w.r.t. the half-mix (generated + pooled) reference.
    results = {'values_over_trials': values_mmd2_over_trials, 'values_hat_over_trials': values_hat_mmd2_over_trials,
               'N': N, 'size': size, 'n_trials': n_trials, 'use_GMM': use_GMM}
    save_results(baseline='MMD_sq_half_mix', exp_name=exp_name, **results)
